'''
Author: SlytherinGe
LastEditTime: 2021-05-21 11:43:49
'''
import mmcv
import matplotlib.pyplot as plt
import numpy as np

# if __name__ == '__main__':
#     feature = mmcv.load('/media/gejunyao/Disk/Gejunyao/exp_results/visualization/middle_part/data.pkl')
#     plt.figure()
#     for i, feat in enumerate(feature):
#         if i%2 == 0:
#             afeat = np.array(feat.cpu()).squeeze(0).squeeze(0)
#             plt.subplot(4,2,i//2+1)
#             plt.imshow(afeat)
#             plt.colorbar()
#             plt.axis('off')
#             plt.title('att at {}'.format(feat.shape))
#         else:
#             afeat = np.array(feat.cpu()).squeeze(0).squeeze(0)
#             plt.subplot(4,2,i//2+5)
#             plt.imshow(afeat)
#             plt.colorbar()
#             plt.axis('off')
#             plt.title('feature {}'.format(feat.shape))            
#     plt.show()

def tensor_to_2d(feat):
    """Return a saved map as a NumPy array with its two leading axes dropped.

    ``feat`` may be a torch tensor or anything ``np.asarray`` accepts. Tensors
    are detached and moved to the CPU first, because converting a tensor that
    requires grad straight to NumPy raises. Both leading axes must have size 1
    (e.g. a ``(1, 1, H, W)`` map becomes ``(H, W)``); otherwise NumPy raises
    ``ValueError``.
    """
    if hasattr(feat, 'detach'):
        feat = feat.detach().cpu().numpy()
    return np.asarray(feat).squeeze(0).squeeze(0)


def split_att_feature_pairs(maps):
    """Pair up an alternating ``[att0, feat0, att1, feat1, ...]`` sequence.

    Every entry is converted with ``tensor_to_2d``. Even positions are
    treated as attention maps and odd positions as feature maps.

    Returns:
        list of ``(attention, feature)`` NumPy array tuples, one per level.

    Raises:
        ValueError: if ``maps`` has an odd number of entries, i.e. an
            attention map is missing its matching feature map.
    """
    arrays = [tensor_to_2d(m) for m in maps]
    if len(arrays) % 2:
        raise ValueError(
            'expected alternating attention/feature maps, '
            'got an odd number of entries ({})'.format(len(arrays)))
    return list(zip(arrays[0::2], arrays[1::2]))


def _show_map(nrows, index, image, title, **colorbar_kwargs):
    """Draw ``image`` with a colorbar in cell ``index`` of an ``nrows`` x 3 grid."""
    plt.subplot(nrows, 3, index)
    plt.imshow(image)
    plt.colorbar(**colorbar_kwargs)
    plt.axis('off')
    plt.title(title)


def visualize_attention(maps, nrows=None):
    """Plot attention, feature and attention-weighted feature for each level.

    Row ``k`` of the figure shows the ``k``-th attention map, the ``k``-th
    feature map, and their element-wise product, then displays the figure.

    Args:
        maps: alternating sequence of attention and feature maps
            (see ``split_att_feature_pairs``).
        nrows: number of grid rows; defaults to the number of pairs.
            It must be at least the number of pairs, otherwise matplotlib
            rejects the subplot index.

    Raises:
        ValueError: if ``maps`` has an odd number of entries.
    """
    pairs = split_att_feature_pairs(maps)
    if nrows is None:
        nrows = max(len(pairs), 1)
    plt.figure()
    for row, (att, featmap) in enumerate(pairs):
        first = row * 3 + 1  # matplotlib subplot indices are 1-based
        _show_map(nrows, first, att, 'att at size {}'.format(att.shape))
        _show_map(nrows, first + 1, featmap,
                  'feature size {}'.format(featmap.shape))
        # Element-wise weighting; shapes must be broadcast-compatible.
        weighted = att * featmap
        _show_map(nrows, first + 2, weighted,
                  'weighted feature at size {}'.format(weighted.shape),
                  fraction=0.05, pad=0.05)
    plt.show()


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Visualize saved attention / feature maps.')
    parser.add_argument(
        'pkl', nargs='?',
        default='/media/gejunyao/Disk/Gejunyao/exp_results/visualization/middle_part/hrfpn_data.pkl',
        help='pickle dump of alternating attention and feature maps')
    args = parser.parse_args()
    # NOTE: a .pkl is unpickled on load, which can execute arbitrary code;
    # only point this at dumps you produced yourself.
    visualize_attention(mmcv.load(args.pkl))